# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import logging
import re
from json import JSONDecodeError
from typing import Any

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import BaseChatModel
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel
from pydantic import Field

from nat.agent.base import AGENT_CALL_LOG_MESSAGE
from nat.agent.base import AGENT_LOG_PREFIX
from nat.agent.base import INPUT_SCHEMA_MESSAGE
from nat.agent.base import NO_INPUT_ERROR_MESSAGE
from nat.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
from nat.agent.base import AgentDecision
from nat.agent.base import BaseAgent

logger = logging.getLogger(__name__)


class ReWOOEvidence(BaseModel):
    """Tool call attached to a single plan step, as parsed from the planner output."""
    placeholder: str  # evidence variable that names this call's result, e.g. "#E1"
    tool: str  # name of the tool to invoke; resolved against the agent's configured tools
    tool_input: Any  # input for the tool; may reference placeholders of earlier steps


class ReWOOPlanStep(BaseModel):
    """One step of the planner output: the plan text plus the evidence (tool call) that backs it."""
    plan: str  # the planner's description of this step
    evidence: ReWOOEvidence  # the tool call whose result is bound to the step's placeholder


class ReWOOGraphState(BaseModel):
    """State schema shared by the planner, executor and solver nodes of the ReWOO Agent Graph."""
    messages: list[BaseMessage] = Field(default_factory=list)  # input and output of the ReWOO Agent
    task: HumanMessage = Field(default_factory=lambda: HumanMessage(content=""))  # the task provided by the user
    plan: AIMessage = Field(
        default_factory=lambda: AIMessage(content=""))  # the raw plan message produced by the planner
    steps: AIMessage = Field(
        default_factory=lambda: AIMessage(content=""))  # the steps to solve the task, parsed from the plan
    # Fields that drive level-by-level (parallel) tool execution
    evidence_map: dict[str, ReWOOPlanStep] = Field(default_factory=dict)  # placeholder -> plan step that defines it
    execution_levels: list[list[str]] = Field(default_factory=list)  # placeholders grouped by dependency level
    current_level: int = Field(default=0)  # index into execution_levels of the level being executed
    intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict)  # placeholder -> tool result
    result: AIMessage = Field(
        default_factory=lambda: AIMessage(content=""))  # the final answer generated by the solver


class ReWOOAgentGraph(BaseAgent):
    """Configurable ReWOO Agent.

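    Args:
        llm: Chat model used by both the planner and the solver.
        planner_prompt: Planner prompt; its ``tools`` and ``tool_names`` variables are filled
            from the configured tools.
        solver_prompt: Solver prompt; the executed plan is supplied through its ``plan`` variable.
        tools: Tools the planner may reference and the executor may call.
        use_tool_schema: Whether each tool description given to the planner includes the
            tool's input schema.
        callbacks: Callback handlers forwarded to LLM and tool calls.
        detailed_logs: Toggles logging of inputs, outputs, and intermediate steps.
        log_response_max_chars: Forwarded to the base agent.
        tool_call_max_retries: Maximum number of retries passed to each tool call.
        raise_tool_call_error: If True, a failed tool call aborts the executor with an
            exception; otherwise the failure is stored as that step's result.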
    """

    def __init__(self,
                 llm: BaseChatModel,
                 planner_prompt: ChatPromptTemplate,
                 solver_prompt: ChatPromptTemplate,
                 tools: list[BaseTool],
                 use_tool_schema: bool = True,
                 callbacks: list[AsyncCallbackHandler] | None = None,
                 detailed_logs: bool = False,
                 log_response_max_chars: int = 1000,
                 tool_call_max_retries: int = 3,
                 raise_tool_call_error: bool = True):
        """Initialise the agent: fill the planner prompt with the tool catalogue and index the tools.

        Args:
            llm: Chat model used by the planner and the solver.
            planner_prompt: Prompt whose ``tools`` and ``tool_names`` variables are pre-filled here.
            solver_prompt: Prompt used by the solver; stored unchanged.
            tools: Tools available to the agent.
            use_tool_schema: When True, each tool description also carries the tool's input schema.
            callbacks: Callback handlers forwarded to the base agent.
            detailed_logs: Toggles logging of inputs, outputs, and intermediate steps.
            log_response_max_chars: Forwarded to the base agent.
            tool_call_max_retries: Retry budget handed to each tool call.
            raise_tool_call_error: Whether the executor raises when a tool call fails.
        """
        super().__init__(llm=llm,
                         tools=tools,
                         callbacks=callbacks,
                         detailed_logs=detailed_logs,
                         log_response_max_chars=log_response_max_chars)

        logger.debug(
            "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
            AGENT_LOG_PREFIX)

        # Build the comma-separated name list and the one-line-per-tool catalogue in a single pass.
        names: list[str] = []
        catalogue: list[str] = []
        for tool in tools:
            names.append(tool.name)
            entry = f"{tool.name}: {tool.description}"
            if use_tool_schema:
                schema_note = INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)
                entry = f"{entry}. {schema_note}"
            catalogue.append(entry)

        self.planner_prompt = planner_prompt.partial(tools="\n".join(catalogue), tool_names=",".join(names))
        self.solver_prompt = solver_prompt
        self.tools_dict = {tool.name: tool for tool in tools}
        self.tool_call_max_retries = tool_call_max_retries
        self.raise_tool_call_error = raise_tool_call_error

        logger.debug("%s Initialized ReWOO Agent Graph", AGENT_LOG_PREFIX)

    def _get_tool(self, tool_name: str) -> BaseTool | None:
        """Look up a configured tool by name.

        Args:
            tool_name: Name of the tool requested by the plan.

        Returns:
            The matching tool, or None when no tool with that name is configured.
        """
        try:
            # dict.get yields None for an unknown name rather than raising; callers check for None.
            return self.tools_dict.get(tool_name)
        except Exception as ex:
            # Only reached if the lookup itself raises; log and propagate.
            logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
            raise

    @staticmethod
    def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
        """
        Get the current execution level and whether it's complete.

        Args:
            state: The ReWOO graph state.

        Returns:
            tuple of (current_level, is_complete). Level -1 means all execution is complete.
        """
        if not state.execution_levels:
            return -1, True

        current_level = state.current_level

        # Check if we've completed all levels
        if current_level >= len(state.execution_levels):
            return -1, True

        # Check if current level is complete
        current_level_placeholders = state.execution_levels[current_level]
        level_complete = all(placeholder in state.intermediate_results for placeholder in current_level_placeholders)

        return current_level, level_complete

    @staticmethod
    def _parse_planner_output(planner_output: str) -> list[ReWOOPlanStep]:

        try:
            return [ReWOOPlanStep(**step) for step in json.loads(planner_output)]
        except Exception as ex:
            raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex

    @staticmethod
    def _parse_planner_dependencies(steps: list[ReWOOPlanStep]) -> tuple[dict[str, ReWOOPlanStep], list[list[str]]]:
        """
        Parse planner steps to identify dependencies and create execution levels for parallel processing.
        This creates a dependency map and identifies which evidence placeholders can be executed in parallel.

        Args:
            steps: list of plan steps from the planner.

        Returns:
            A mapping from evidence placeholders to step info and execution levels for parallel processing.
        """
        # First pass: collect all evidence placeholders and their info
        evidences: dict[str, ReWOOPlanStep] = {
            step.evidence.placeholder: step
            for step in steps if step.evidence and step.evidence.placeholder
        }

        # Second pass: find dependencies now that we have all placeholders
        dependencies = {
            step.evidence.placeholder: [
                var for var in re.findall(r"#E\d+", str(step.evidence.tool_input))
                if var in evidences and var != step.evidence.placeholder
            ]
            for step in steps if step.evidence and step.evidence.placeholder
        }

        # Create execution levels using topological sort
        levels: list[list[str]] = []
        remaining = dict(dependencies)

        while remaining:
            # Find items with no dependencies (can be executed in parallel)
            ready = [placeholder for placeholder, deps in remaining.items() if not deps]

            if not ready:
                raise ValueError("Circular dependency detected in planner output")

            levels.append(ready)

            # Remove completed items from remaining
            for placeholder in ready:
                remaining.pop(placeholder)

            # Remove completed items from other dependencies
            for ph, deps in list(remaining.items()):
                remaining[ph] = list(set(deps) - set(ready))
        return evidences, levels

    @staticmethod
    def _replace_placeholder(placeholder: str, tool_input: str | dict, tool_output: str | dict) -> str | dict:

        # Replace the placeholders in the tool input with the previous tool output
        if isinstance(tool_input, dict):
            for key, value in tool_input.items():
                if value is not None:
                    if value == placeholder:
                        tool_input[key] = tool_output
                    elif isinstance(value, str) and placeholder in value:
                        # If the placeholder is part of the value, replace it with the stringified output
                        tool_input[key] = value.replace(placeholder, str(tool_output))

        elif isinstance(tool_input, str):
            tool_input = tool_input.replace(placeholder, str(tool_output))

        else:
            assert False, f"Unexpected type for tool_input: {type(tool_input)}"

        return tool_input

    @staticmethod
    def _parse_tool_input(tool_input: str | dict):
        """Turn a tool input into structured data where possible.

        A dict is returned unchanged. A string is stripped and decoded as JSON; if that fails,
        decoding is retried with single quotes swapped for double quotes; if that fails too,
        the stripped string itself is returned.

        Args:
            tool_input: The tool input, either already structured or as text.

        Returns:
            The dict as given, the decoded JSON value, or the stripped raw string.
        """
        # Structured input needs no parsing.
        if isinstance(tool_input, dict):
            logger.debug("%s Tool input is already a dictionary. Use the tool input as is.", AGENT_LOG_PREFIX)
            return tool_input

        stripped_input = tool_input.strip()

        # First attempt: the text is valid JSON as written.
        try:
            parsed_input = json.loads(stripped_input)
        except JSONDecodeError:
            pass
        else:
            logger.debug("%s Successfully parsed structured tool input", AGENT_LOG_PREFIX)
            return parsed_input

        # Second attempt: the text uses single quotes where JSON needs double quotes.
        try:
            parsed_input = json.loads(stripped_input.replace("'", '"'))
        except JSONDecodeError:
            # Not JSON in either form; hand the text to the tool unchanged.
            logger.debug("%s Unable to parse structured tool input. Using raw tool input as is.", AGENT_LOG_PREFIX)
            return stripped_input

        logger.debug("%s Successfully parsed structured tool input after replacing single quotes with double quotes",
                     AGENT_LOG_PREFIX)
        return parsed_input

    async def planner_node(self, state: ReWOOGraphState):
        """Ask the LLM for a plan and turn it into an execution schedule.

        Args:
            state: The ReWOO graph state; ``task`` and ``messages`` are read.

        Returns:
            A state update with the plan message, the placeholder-to-step map, the execution
            levels and ``current_level`` reset to 0. If the task is empty, a state update that
            only sets ``result`` to the no-input error message.

        Raises:
            Exception: Any failure while calling the LLM or parsing its output is logged and re-raised.
        """
        try:
            logger.debug("%s Starting the ReWOO Planner Node", AGENT_LOG_PREFIX)

            planner_chain = self.planner_prompt | self.llm
            user_task = str(state.task.content)
            if user_task == "":
                logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
                return {"result": NO_INPUT_ERROR_MESSAGE}

            planner_inputs = {"task": user_task, "chat_history": self._get_chat_history(state.messages)}
            run_config = RunnableConfig(callbacks=self.callbacks)  # type: ignore
            plan_message = await self._stream_llm(planner_chain, planner_inputs, run_config)
            plan_text = str(plan_message.content)

            # Parse the steps, then group them into dependency levels for parallel execution.
            parsed_steps = self._parse_planner_output(plan_text)
            evidence_map, execution_levels = self._parse_planner_dependencies(parsed_steps)

            if self.detailed_logs:
                logger.info("ReWOO agent planner output: %s", AGENT_CALL_LOG_MESSAGE % (user_task, plan_text))
                logger.info("ReWOO agent execution levels: %s", execution_levels)

            return dict(plan=plan_message,
                        evidence_map=evidence_map,
                        execution_levels=execution_levels,
                        current_level=0)

        except Exception as ex:
            logger.error("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def executor_node(self, state: ReWOOGraphState):
        """
        Execute, concurrently, every tool of the current dependency level that has no result yet.

        If the current level is already complete, no tool is run and the node only advances
        ``current_level``; the tools of the next level run on the following invocation.

        Args:
            state: The ReWOO graph state.

        Returns:
            Either ``{"current_level": <next level>}`` or ``{"intermediate_results": <updated map>}``.

        Raises:
            RuntimeError: If invoked when no level is left to execute, or if a tool returns a
                message with error status while ``raise_tool_call_error`` is set.
            BaseException: The exception raised by a tool call, re-raised when
                ``raise_tool_call_error`` is set.
        """
        try:
            logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)

            current_level, level_complete = self._get_current_level_status(state)

            # Should not be invoked if all levels are complete
            if current_level < 0:
                logger.error("%s ReWOO Executor invoked after all levels complete", AGENT_LOG_PREFIX)
                raise RuntimeError("ReWOO Executor invoked after all levels complete")

            # If current level is already complete, move to next level
            if level_complete:
                new_level = current_level + 1
                logger.debug("%s Level %s complete, moving to level %s", AGENT_LOG_PREFIX, current_level, new_level)
                return {"current_level": new_level}

            # Placeholders of this level that still need a result, kept in plan order so that
            # logging and the first error raised do not depend on set iteration order.
            # The level is incomplete here, so at least one placeholder is pending.
            pending_placeholders = [
                placeholder for placeholder in state.execution_levels[current_level]
                if placeholder not in state.intermediate_results
            ]

            logger.debug("%s Executing level %s with %s tools in parallel: %s",
                         AGENT_LOG_PREFIX,
                         current_level,
                         len(pending_placeholders),
                         pending_placeholders)

            # One coroutine per pending placeholder; gather runs them concurrently and, with
            # return_exceptions=True, hands failures back as values in the same order.
            tasks = [
                self._execute_single_tool(placeholder, state.evidence_map[placeholder], state.intermediate_results)
                for placeholder in pending_placeholders
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Work on a copy so the incoming state is only changed through the returned update.
            updated_intermediate_results = dict(state.intermediate_results)

            for placeholder, result in zip(pending_placeholders, results):
                if isinstance(result, BaseException):
                    logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
                    # Record the failure as this step's result so later steps and the solver can see it
                    error_message = f"Tool execution failed: {str(result)}"
                    updated_intermediate_results[placeholder] = ToolMessage(content=error_message,
                                                                            tool_call_id=placeholder)
                    if self.raise_tool_call_error:
                        raise result
                    continue

                updated_intermediate_results[placeholder] = result
                # A tool may report failure through the message status instead of raising
                reported_error = isinstance(result, ToolMessage) and getattr(result, "status", None) == "error"
                if reported_error and self.raise_tool_call_error:
                    logger.error("%s Tool call failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result.content)
                    raise RuntimeError(f"Tool call failed: {result.content}")

            if self.detailed_logs:
                logger.info("%s Completed level %s with %s tools",
                            AGENT_LOG_PREFIX,
                            current_level,
                            len(pending_placeholders))

            return {"intermediate_results": updated_intermediate_results}

        except Exception as ex:
            logger.error("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def _execute_single_tool(self,
                                   placeholder: str,
                                   step_info: ReWOOPlanStep,
                                   intermediate_results: dict[str, ToolMessage]) -> ToolMessage:
        """
        Resolve the placeholders in one step's tool input and call its tool.

        Args:
            placeholder: The evidence placeholder (e.g., "#E1") of the step being executed.
            step_info: The plan step; its evidence holds the tool name and tool input.
            intermediate_results: Results recorded so far, used to resolve placeholders.

        Returns:
            The tool's response, or a ToolMessage carrying the tool-not-found message when the
            requested tool is not configured.
        """
        evidence = step_info.evidence
        resolved_input = evidence.tool_input

        # Feed every recorded result through the placeholder substitution.
        for known_placeholder, known_message in intermediate_results.items():
            substitution = known_message.content
            # List content is reduced to its first element, which is expected to be a dict
            if isinstance(substitution, list):
                substitution = substitution[0]
                assert isinstance(substitution, dict)

            resolved_input = self._replace_placeholder(known_placeholder, resolved_input, substitution)

        requested_tool = self._get_tool(evidence.tool)
        if requested_tool:
            if self.detailed_logs:
                logger.debug("%s Calling tool %s with input: %s",
                             AGENT_LOG_PREFIX,
                             requested_tool.name,
                             resolved_input)

            parsed_input = self._parse_tool_input(resolved_input)
            run_config = RunnableConfig(callbacks=self.callbacks)  # type: ignore
            tool_response = await self._call_tool(requested_tool,
                                                  parsed_input,
                                                  run_config,
                                                  max_retries=self.tool_call_max_retries)

            if self.detailed_logs:
                self._log_tool_response(requested_tool.name, parsed_input, str(tool_response))

            return tool_response

        # The plan names a tool that is not configured: report it as the step's result.
        available_tool_names = list(self.tools_dict.keys())
        logger.warning(
            "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
            "there is no tool with that name: %s",
            AGENT_LOG_PREFIX,
            evidence.tool,
            available_tool_names)

        not_found_text = TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=evidence.tool, tools=available_tool_names)
        return ToolMessage(content=not_found_text, tool_call_id=placeholder)

    async def solver_node(self, state: ReWOOGraphState):
        """Compose the executed plan with its results and ask the LLM for the final answer.

        For every planned step the solver prompt receives three lines: the plan text, the tool
        call written as ``<placeholder> = <tool>[<input>]`` with placeholders resolved, and the
        tool result.

        Args:
            state: The ReWOO graph state; ``evidence_map``, ``intermediate_results`` and ``task`` are read.

        Returns:
            A state update whose ``result`` is the solver's output message.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)

            # Collected per step and joined once at the end.
            plan_sections: list[str] = []
            for placeholder, step_info in state.evidence_map.items():
                evidence_info = step_info.evidence
                tool_name = evidence_info.tool

                # Replace placeholders in tool input with actual results
                final_tool_input = evidence_info.tool_input
                for ph_key, tool_output in state.intermediate_results.items():
                    tool_output_content = tool_output.content
                    # If the content is a list, get the first element which should be a dict
                    if isinstance(tool_output_content, list):
                        tool_output_content = tool_output_content[0]
                        assert isinstance(tool_output_content, dict)

                    final_tool_input = self._replace_placeholder(ph_key, final_tool_input, tool_output_content)

                # Get the final result for this placeholder; a step without a result shows as empty
                final_result = ""
                if placeholder in state.intermediate_results:
                    result_content = state.intermediate_results[placeholder].content
                    if isinstance(result_content, list):
                        result_content = result_content[0]
                        if isinstance(result_content, dict):
                            final_result = str(result_content)
                    else:
                        final_result = str(result_content)

                plan_sections.append('\n'.join([
                    f"Plan: {step_info.plan}",
                    # The closing bracket completes the `Tool[input]` notation shown to the solver.
                    f"{placeholder} = {tool_name}[{final_tool_input}]",
                    f"Result: {final_result}\n\n"
                ]))

            plan = "".join(plan_sections)

            task = str(state.task.content)
            solver_prompt = self.solver_prompt.partial(plan=plan)
            solver = solver_prompt | self.llm

            output_message = await self._stream_llm(solver, {"task": task},
                                                    RunnableConfig(callbacks=self.callbacks))  # type: ignore

            if self.detailed_logs:
                solver_output_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(output_message.content))
                logger.info("ReWOO agent solver output: %s", solver_output_log_message)

            return {"result": output_message}

        except Exception as ex:
            logger.error("%s Failed to call solver_node: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def conditional_edge(self, state: ReWOOGraphState):
        """Decide whether the graph runs the executor again or hands over to the solver.

        Args:
            state: The ReWOO graph state.

        Returns:
            ``AgentDecision.END`` when no level is left to execute (or when the decision itself
            fails), otherwise ``AgentDecision.TOOL``.
        """
        try:
            logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)

            current_level, level_complete = self._get_current_level_status(state)

            # Done when there is no active level, or the active level is finished and is the last one.
            nothing_left = current_level == -1 or (level_complete
                                                   and current_level + 1 >= len(state.execution_levels))
            if nothing_left:
                logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
                return AgentDecision.END

            logger.debug("%s Continuing with executor (level %s, complete: %s)",
                         AGENT_LOG_PREFIX,
                         current_level,
                         level_complete)
            return AgentDecision.TOOL

        except Exception as ex:
            # A failure to decide ends the tool loop rather than propagating.
            logger.exception("%s Failed to determine whether agent is calling a tool: %s", AGENT_LOG_PREFIX, ex)
            logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
            return AgentDecision.END

    async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
        """Assemble and compile the planner -> executor -> solver graph.

        Args:
            state_schema: The state class the graph is built on.

        Returns:
            The compiled graph, which is also stored on ``self.graph``.

        Raises:
            Exception: Any failure while building or compiling is logged and re-raised.
        """
        try:
            logger.debug("%s Building and compiling the ReWOO Graph", AGENT_LOG_PREFIX)

            graph = StateGraph(state_schema)
            graph.add_node("planner", self.planner_node)
            graph.add_node("executor", self.executor_node)
            graph.add_node("solver", self.solver_node)

            graph.add_edge("planner", "executor")
            # After each executor run, conditional_edge either loops back to the executor
            # (TOOL) or hands over to the solver (END).
            graph.add_conditional_edges("executor",
                                        self.conditional_edge, {
                                            AgentDecision.TOOL: "executor", AgentDecision.END: "solver"
                                        })

            graph.set_entry_point("planner")
            graph.set_finish_point("solver")

            self.graph = graph.compile()
            logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)

            return self.graph

        except Exception as ex:
            logger.error("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def build_graph(self):
        """Build the ReWOO graph on ``ReWOOGraphState`` and return the compiled graph.

        Returns:
            The compiled graph stored on ``self.graph``.

        Raises:
            Exception: Any failure from the underlying build is logged and re-raised.
        """
        try:
            await self._build_graph(state_schema=ReWOOGraphState)
            logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)
            return self.graph
        except Exception as ex:
            logger.error("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
            raise

    @staticmethod
    def validate_planner_prompt(planner_prompt: str) -> bool:
        errors = []
        if not planner_prompt:
            errors.append("The planner prompt cannot be empty.")
        required_prompt_variables = {
            "{tools}": "The planner prompt must contain {tools} so the planner agent knows about configured tools.",
            "{tool_names}": "The planner prompt must contain {tool_names} so the planner agent knows tool names."
        }
        for variable_name, error_message in required_prompt_variables.items():
            if variable_name not in planner_prompt:
                errors.append(error_message)
        if errors:
            error_text = "\n".join(errors)
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            raise ValueError(error_text)
        return True

    @staticmethod
    def validate_solver_prompt(solver_prompt: str) -> bool:
        errors = []
        if not solver_prompt:
            errors.append("The solver prompt cannot be empty.")
        if errors:
            error_text = "\n".join(errors)
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            raise ValueError(error_text)
        return True
